{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "60cba92f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 77, 768])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "class Embed(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embed = torch.nn.Embedding(49408, 768)\n",
    "        self.pos_embed = torch.nn.Embedding(77, 768)\n",
    "\n",
    "        self.register_buffer('pos_ids', torch.arange(77).unsqueeze(dim=0))\n",
    "\n",
    "    def forward(self, input_ids):\n",
    "        #input_ids -> [b, 77]\n",
    "\n",
    "        #[b, 77] -> [b, 77, 768]\n",
    "        embed = self.embed(input_ids)\n",
    "\n",
    "        #[1, 77] -> [1, 77, 768]\n",
    "        pos_embed = self.pos_embed(self.pos_ids)\n",
    "\n",
    "        #[b, 77, 768]\n",
    "        return embed + pos_embed\n",
    "\n",
    "\n",
    "#smoke test: [2, 77] integer ids -> [2, 77, 768]\n",
    "Embed()(torch.ones(2, 77, dtype=torch.long)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1820e1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 77, 768])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Atten(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.q = torch.nn.Linear(768, 768)\n",
    "        self.k = torch.nn.Linear(768, 768)\n",
    "        self.v = torch.nn.Linear(768, 768)\n",
    "        self.out = torch.nn.Linear(768, 768)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #x -> [b, 77, 768]\n",
    "\n",
    "        b = x.shape[0]\n",
    "\n",
    "        #维度不变\n",
    "        #[b, 77, 768]\n",
    "        q = self.q(x) * 0.125\n",
    "        k = self.k(x)\n",
    "        v = self.v(x)\n",
    "\n",
    "        #拆分注意力头\n",
    "        #[b, 77, 768] -> [b, 77, 12, 64] -> [b, 12, 77, 64] -> [b*12, 77, 64]\n",
    "        q = q.reshape(b, 77, 12, 64).transpose(1, 2).reshape(b * 12, 77, 64)\n",
    "        k = k.reshape(b, 77, 12, 64).transpose(1, 2).reshape(b * 12, 77, 64)\n",
    "        v = v.reshape(b, 77, 12, 64).transpose(1, 2).reshape(b * 12, 77, 64)\n",
    "\n",
    "        #计算qk乘积\n",
    "        #[b*12, 77, 64] * [b*12, 64, 77] -> [b*12, 77, 77]\n",
    "        attn = torch.bmm(q, k.transpose(1, 2))\n",
    "\n",
    "        #[b*12, 77, 77] -> [b, 12, 77, 77]\n",
    "        attn = attn.reshape(b, 12, 77, 77)\n",
    "\n",
    "        #覆盖mask\n",
    "        def get_mask(b):\n",
    "            mask = torch.empty(b, 77, 77)\n",
    "\n",
    "            #上三角的部分置为负无穷\n",
    "            mask.fill_(-float('inf'))\n",
    "\n",
    "            #对角线和以下的位置为0\n",
    "            mask.triu_(1)\n",
    "\n",
    "            return mask.unsqueeze(1)\n",
    "\n",
    "        #[b, 12, 77, 77] + [b, 1, 77, 77] -> [b, 12, 77, 77]\n",
    "        attn = attn + get_mask(attn.shape[0]).to(attn.device)\n",
    "\n",
    "        #[b, 12, 77, 77] -> [b*12, 77, 77]\n",
    "        attn = attn.reshape(b * 12, 77, 77)\n",
    "\n",
    "        #计算softmax,被mask的部分值为0\n",
    "        attn = attn.softmax(dim=-1)\n",
    "\n",
    "        #计算和v的乘积\n",
    "        #[b*12, 77, 77] * [b*12, 77, 64] -> [b*12, 77, 64]\n",
    "        attn = torch.bmm(attn, v)\n",
    "\n",
    "        #[b*12, 77, 64] -> [b, 12, 77, 64] -> [b, 77, 12, 64] -> [b, 77, 768]\n",
    "        attn = attn.reshape(b, 12, 77, 64).transpose(1, 2).reshape(b, 77, 768)\n",
    "\n",
    "        #线性输出,维度不变\n",
    "        #[b, 77, 768]\n",
    "        return self.out(attn)\n",
    "\n",
    "\n",
    "#smoke test: a random [2, 77, 768] batch keeps its shape\n",
    "Atten()(torch.randn((2, 77, 768))).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "07982501",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 77, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ClipEncoder(torch.nn.Module):\n",
    "    \"\"\"One encoder layer: pre-norm attention and a pre-norm MLP, each with a residual.\n",
    "\n",
    "    The sub-modules are named s1, s2 and s3 and are addressed by position\n",
    "    (s1[0], s1[1], s2[0], s2[1]) when the pretrained weights are loaded.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #attention branch: LayerNorm -> Atten\n",
    "        self.s1 = torch.nn.Sequential(torch.nn.LayerNorm(768), Atten())\n",
    "\n",
    "        #first half of the MLP branch: LayerNorm -> expand 768 -> 3072\n",
    "        self.s2 = torch.nn.Sequential(torch.nn.LayerNorm(768),\n",
    "                                      torch.nn.Linear(768, 3072))\n",
    "\n",
    "        #second half of the MLP branch: project 3072 -> 768\n",
    "        self.s3 = torch.nn.Linear(3072, 768)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #x -> [b, 77, 768]\n",
    "\n",
    "        #attention with residual connection, shape unchanged\n",
    "        #[b, 77, 768]\n",
    "        hidden = x + self.s1(x)\n",
    "\n",
    "        #[b, 77, 768] -> [b, 77, 3072]\n",
    "        expanded = self.s2(hidden)\n",
    "\n",
    "        #sigmoid-gated activation t * sigmoid(1.702 * t), shape unchanged\n",
    "        #[b, 77, 3072]\n",
    "        gate = torch.sigmoid(expanded * 1.702)\n",
    "        activated = expanded * gate\n",
    "\n",
    "        #[b, 77, 3072] -> [b, 77, 768], added back onto the attention output\n",
    "        return hidden + self.s3(activated)\n",
    "\n",
    "\n",
    "#smoke test: a random [2, 77, 768] batch keeps its shape\n",
    "ClipEncoder()(torch.randn((2, 77, 768))).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc306ffd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 77, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#After the simplification the whole text encoder is surprisingly little code.\n",
    "#Flat layout: encoder[0] is the embedding, encoder[1]..encoder[12] are the 12\n",
    "#encoder layers and encoder[13] is the final LayerNorm.\n",
    "encoder = torch.nn.Sequential(\n",
    "    Embed(),\n",
    "    *[ClipEncoder() for _ in range(12)],\n",
    "    torch.nn.LayerNorm(768),\n",
    ")\n",
    "\n",
    "#smoke test: [2, 77] token ids -> [2, 77, 768]\n",
    "encoder(torch.ones(2, 77).long()).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "73142c37",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import CLIPTextModel\n",
    "\n",
    "#Load the pretrained reference text encoder that supplies the weights.\n",
    "params = CLIPTextModel.from_pretrained(\n",
    "    'lansinuote/diffsion_from_scratch.params',\n",
    "    subfolder='text_encoder',\n",
    ")\n",
    "pretrained = params.text_model\n",
    "\n",
    "#token embedding\n",
    "encoder[0].embed.load_state_dict(\n",
    "    pretrained.embeddings.token_embedding.state_dict())\n",
    "\n",
    "#position embedding\n",
    "encoder[0].pos_embed.load_state_dict(\n",
    "    pretrained.embeddings.position_embedding.state_dict())\n",
    "\n",
    "#the 12 encoder layers live at encoder[1] .. encoder[12]\n",
    "for i in range(12):\n",
    "    ours = encoder[i + 1]\n",
    "    theirs = pretrained.encoder.layers[i]\n",
    "\n",
    "    #(destination, source) module pairs, in the order the weights are copied\n",
    "    pairs = [\n",
    "        (ours.s1[0], theirs.layer_norm1),  #first norm\n",
    "        (ours.s1[1].q, theirs.self_attn.q_proj),  #attention q projection\n",
    "        (ours.s1[1].k, theirs.self_attn.k_proj),  #attention k projection\n",
    "        (ours.s1[1].v, theirs.self_attn.v_proj),  #attention v projection\n",
    "        (ours.s1[1].out, theirs.self_attn.out_proj),  #attention output projection\n",
    "        (ours.s2[0], theirs.layer_norm2),  #second norm\n",
    "        (ours.s2[1], theirs.mlp.fc1),  #first fc of the mlp\n",
    "        (ours.s3, theirs.mlp.fc2),  #second fc of the mlp\n",
    "    ]\n",
    "    for dst, src in pairs:\n",
    "        dst.load_state_dict(src.state_dict())\n",
    "\n",
    "#output norm; kept as the last expression so the cell displays its result\n",
    "encoder[13].load_state_dict(pretrained.final_layer_norm.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d2bab9bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Check the hand-written encoder against the reference model on one fixed input.\n",
    "token_ids = torch.arange(77).unsqueeze(dim=0)\n",
    "\n",
    "with torch.no_grad():\n",
    "    ours = encoder(token_ids)\n",
    "    reference = params(token_ids).last_hidden_state\n",
    "\n",
    "#NOTE(review): compared with a tolerance instead of ==, on the assumption that\n",
    "#bitwise-identical floats are not guaranteed across library versions - confirm\n",
    "torch.isclose(ours, reference, rtol=1e-4, atol=1e-4).all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
